# Source code for hysop.numerics.fft.fftw_fft

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
FFT interface for fast Fourier transforms using the FFTW backend (through pyfftw).
:class:`~hysop.numerics.FftwFFT`
:class:`~hysop.numerics.FftwFFTPlan`
"""

import warnings
import pyfftw
import numpy as np

from hysop import (
    __FFTW_NUM_THREADS__,
    __FFTW_PLANNER_EFFORT__,
    __FFTW_PLANNER_TIMELIMIT__,
    __VERBOSE__,
)
from hysop.tools.io_utils import IO
from hysop.tools.htypes import first_not_None
from hysop.tools.misc import prod
from hysop.tools.string_utils import framed_str
from hysop.tools.cache import load_data_from_cache, update_cache
from hysop.numerics.fft.fft import HysopFFTWarning, bytes2str
from hysop.numerics.fft.host_fft import HostFFTPlanI, HostFFTI, HostArray


class FftwFFTPlan(HostFFTPlanI):
    """
    Build and wraps a FFTW plan.
    Emit warnings when SIMD alignment is not used.
    Emit warnings when changing input and output alignment.
    """

    # When True, FFTW wisdom is imported from / exported to the cache file
    # returned by cache_file() (see load_wisdom / save_wisdom).
    __FFTW_USE_CACHE__ = True

    @classmethod
    def cache_file(cls):
        """Return the path of the file used to cache FFTW wisdom."""
        _cache_dir = IO.cache_path() + "/numerics"
        _cache_file = _cache_dir + "/fftw_wisdom.pklz"
        return _cache_file

    @classmethod
    def load_wisdom(cls, h):
        """
        Import cached FFTW wisdom stored under key h into pyfftw.

        Returns True if some wisdom was found and imported, False otherwise
        (including when caching is disabled).
        """
        if cls.__FFTW_USE_CACHE__:
            wisdom = load_data_from_cache(cls.cache_file(), h)
            if wisdom is not None:
                pyfftw.import_wisdom(wisdom)
                return True
        return False

    @classmethod
    def save_wisdom(cls, h, plan):
        """
        Export the current pyfftw wisdom and store it in the cache under key h.

        The plan argument is not used here; the whole exported wisdom is saved.
        """
        if cls.__FFTW_USE_CACHE__:
            wisdom = pyfftw.export_wisdom()
            update_cache(cls.cache_file(), h, wisdom)

    @staticmethod
    def _planning_summary(plan_kwds):
        """Return a human readable summary of the pyfftw.FFTW planning arguments."""

        def fmt_array(name):
            arr = plan_kwds[name]
            return "shape={:<16} strides={:<16} dtype={:<16}".format(
                str(arr.shape) + ",", str(arr.strides) + ",", str(arr.dtype)
            )

        return """
    in_array:           {}
    out_array:          {}
    axes:               {}
    direction:          {}
    threads:            {}
    flags:              {}
    planning timelimit: {}""".format(
            fmt_array("input_array"),
            fmt_array("output_array"),
            plan_kwds["axes"],
            plan_kwds["direction"],
            plan_kwds["threads"],
            " | ".join(plan_kwds["flags"]),
            plan_kwds["planning_timelimit"],
        )

    def __init__(self, a, out, scaling=None, **plan_kwds):
        """
        Plan a pyfftw.FFTW transform reading from a and writing to out.

        Parameters
        ----------
        a, out: HostArray or array
            Input and output arrays. HostArray instances are unwrapped to
            their .handle before being handed to pyfftw.FFTW.
        scaling: float, optional
            When not None, the output is multiplied by this factor after each
            execution (see execute()).
        plan_kwds:
            Keyword arguments forwarded to pyfftw.FFTW. 'flags' must be
            present; 'axes', 'direction', 'threads' and 'planning_timelimit'
            are also printed in verbose mode. An optional 'verbose' entry is
            consumed here and not forwarded.
        """
        verbose = plan_kwds.pop("verbose", __VERBOSE__)
        super().__init__(verbose=verbose)

        plan_kwds["input_array"] = a.handle if isinstance(a, HostArray) else a
        plan_kwds["output_array"] = out.handle if isinstance(out, HostArray) else out

        if self.verbose:
            title = f" Planning {self.__class__.__name__} "
            print()
            print(framed_str(title, self._planning_summary(plan_kwds), c="*"))

        # All plans currently share a single wisdom cache entry (key None):
        # per-plan hashing of arrays/axes/direction is disabled.
        h = None

        # Flags without FFTW_WISDOM_ONLY, order preserved.
        base_flags = tuple(
            flag for flag in plan_kwds["flags"] if flag != "FFTW_WISDOM_ONLY"
        )

        plan = None
        if self.load_wisdom(h):
            # Try to build the plan from wisdom only. This can fail, for
            # example when the cached wisdom only holds measure-level
            # knowledge; pyfftw then raises RuntimeError.
            plan_kwds["flags"] = base_flags + ("FFTW_WISDOM_ONLY",)
            try:
                plan = pyfftw.FFTW(**plan_kwds)
            except RuntimeError:
                plan = None

        if plan is None:
            # Plan from scratch and persist the (possibly enriched) wisdom.
            plan_kwds["flags"] = base_flags
            plan = pyfftw.FFTW(**plan_kwds)
            self.save_wisdom(h, plan)

        if not plan.simd_aligned:
            msg = "Resulting plan is not SIMD aligned ({} bytes boundary)."
            msg = msg.format(pyfftw.simd_alignment)
            warnings.warn(msg, HysopFFTWarning)

        self.plan = plan
        self.scaling = scaling
        self.out = out
        self.a = a

    @property
    def input_array(self):
        """Input array as given by the caller (possibly a HostArray)."""
        return self.a

    @property
    def output_array(self):
        """Output array as given by the caller (possibly a HostArray)."""
        return self.out

    def check_new_inputs(self, a, out):
        """
        Record new input/output arrays and warn on alignment mismatch.

        A warning is emitted when a non-None new array is not aligned on the
        alignment the plan was built with. Returns the (a, out) pair with any
        HostArray unwrapped to its .handle. The pyfftw plan itself is not
        updated here.
        """
        plan = self.plan
        a_handle = a.handle if isinstance(a, HostArray) else a
        out_handle = out.handle if isinstance(out, HostArray) else out

        # Alignment must be checked on the underlying buffers, not on wrappers.
        if (a_handle is not None) and (
            not pyfftw.is_byte_aligned(a_handle, n=plan.input_alignment)
        ):
            msg = "New input array is not aligned on original planned input alignment of {} bytes."
            msg += "\nA copy will be made."
            msg = msg.format(plan.input_alignment)
            warnings.warn(msg, HysopFFTWarning)
        if (out_handle is not None) and (
            not pyfftw.is_byte_aligned(out_handle, n=plan.output_alignment)
        ):
            msg = "New output array is not aligned on original planned output alignment of {} bytes."
            msg += "\nA copy will be made."
            msg = msg.format(plan.output_alignment)
            warnings.warn(msg, HysopFFTWarning)

        self.out = out
        self.a = a
        return (a_handle, out_handle)

    def execute(self):
        """
        Execute plan on current input and output array.

        The plan is run through pyfftw's FFTW.__call__ with its default
        arguments. The output is then multiplied by self.scaling if it is set.
        """
        self.plan.__call__()
        if self.scaling is not None:
            self.output_array[...] *= self.scaling

    def __call__(self):
        """
        Execute the plan on the current input and output arrays.
        Use check_new_inputs() to register different arrays; arrays not
        aligned on the original byte boundary will result in a copy being made.
        Return output array for convenience.
        """
        self.execute()
        return self.output_array
class FftwFFT(HostFFTI):
    """
    Interface to compute local to process FFT-like transforms using the FFTW backend.

    Fftw fft backend has many advantages:
        - single, double and long double precision supported
        - no intermediate temporary buffers created at each call.
        - planning capability with caching
        - multithreading capability

    Planning destroys initial arrays content.
    """

    # FFTW real-to-real kinds, indexed by (type - 1) for DCT-I..IV / DST-I..IV.
    _DCT_KINDS = ("FFTW_REDFT00", "FFTW_REDFT10", "FFTW_REDFT01", "FFTW_REDFT11")
    _DST_KINDS = ("FFTW_RODFT00", "FFTW_RODFT10", "FFTW_RODFT01", "FFTW_RODFT11")

    def __init__(
        self,
        threads=None,
        planner_effort=None,
        planning_timelimit=None,
        destroy_input=False,
        warn_on_misalignment=True,
        warn_on_allocation=True,
        error_on_allocation=False,
        backend=None,
        allocator=None,
        **kwds,
    ):
        """
        threads, planner_effort and planning_timelimit default to the global
        __FFTW_NUM_THREADS__, __FFTW_PLANNER_EFFORT__ and
        __FFTW_PLANNER_TIMELIMIT__ settings when None.
        """
        threads = first_not_None(threads, __FFTW_NUM_THREADS__)
        planner_effort = first_not_None(planner_effort, __FFTW_PLANNER_EFFORT__)
        planning_timelimit = first_not_None(
            planning_timelimit, __FFTW_PLANNER_TIMELIMIT__
        )

        super().__init__(
            backend=backend,
            allocator=allocator,
            warn_on_allocation=warn_on_allocation,
            error_on_allocation=error_on_allocation,
            **kwds,
        )
        self.supported_ftypes = (np.float32, np.float64, np.longdouble)
        self.supported_ctypes = (np.complex64, np.complex128, np.clongdouble)
        self.supported_cosine_transforms = (1, 2, 3, 4)
        self.supported_sine_transforms = (1, 2, 3, 4)

        self.threads = threads
        self.planner_effort = planner_effort
        self.planning_timelimit = planning_timelimit
        self.destroy_input = destroy_input
        self.warn_on_misalignment = warn_on_misalignment

    @classmethod
    def check_alignment(cls, a, out):
        """Check SIMD alignment of input and output arrays.

        Input and output are checked independently; a warning is emitted for
        each non-None array (HostArray instances are unwrapped) that is not
        aligned on pyfftw.simd_alignment.
        """
        msg0 = "{} array is not aligned on SIMD alignment ({} bytes)."
        msg0 = msg0.format("{}", pyfftw.simd_alignment)
        for name, arr in (("Input", a), ("Output", out)):
            if isinstance(arr, HostArray):
                arr = arr.handle
            if (arr is not None) and not pyfftw.is_byte_aligned(arr):
                warnings.warn(msg0.format(name), HysopFFTWarning)

    def bake_kwds(self, **kwds):
        """
        Translate transform keyword arguments into FftwFFTPlan arguments.

        'a', 'out' and 'direction' are required. Either 'axes' or 'axis' must
        be given ('axes' wins when both are present). 'threads',
        'planning_timelimit', 'planner_effort' and 'destroy_input' default to
        the values stored on this instance. 'verbose' defaults to
        __VERBOSE__. 'wisdom_only' defaults to False.

        Raises RuntimeError on any unrecognized keyword.
        """
        plan_kwds = {}
        plan_kwds["a"] = kwds.pop("a")
        plan_kwds["out"] = kwds.pop("out")
        plan_kwds["direction"] = kwds.pop("direction")
        # Do not evaluate kwds.pop("axis") eagerly: only 'axes' may be given.
        if "axes" in kwds:
            plan_kwds["axes"] = kwds.pop("axes")
            kwds.pop("axis", None)
        else:
            plan_kwds["axes"] = (kwds.pop("axis"),)
        plan_kwds["threads"] = kwds.pop("threads", self.threads)
        plan_kwds["verbose"] = kwds.pop("verbose", __VERBOSE__)
        plan_kwds["planning_timelimit"] = kwds.pop(
            "planning_timelimit", self.planning_timelimit
        )

        flags = (kwds.pop("planner_effort", self.planner_effort),)
        if kwds.pop("destroy_input", self.destroy_input) is True:
            flags += ("FFTW_DESTROY_INPUT",)
        if kwds.pop("wisdom_only", False) is True:
            flags += ("FFTW_WISDOM_ONLY",)
        plan_kwds["flags"] = flags

        if kwds:
            msg = "Unknown keyword arguments: {}"
            msg = msg.format(", ".join(f"'{kwd}'" for kwd in kwds.keys()))
            raise RuntimeError(msg)

        return plan_kwds

    @staticmethod
    def _r2r_kind(kinds, ttype):
        """Return the FFTW real-to-real kind for transform type ttype (1-based)."""
        itype = int(ttype)
        # Guard against negative indexing (e.g. type=0 selecting the last kind).
        if not 1 <= itype <= len(kinds):
            raise ValueError(
                f"Unsupported real-to-real transform type {ttype}, "
                f"expected an integer between 1 and {len(kinds)}."
            )
        return kinds[itype - 1]

    def _build_plan(self, a, out, shape, dtype, axis, direction, kwds, scaling=None):
        """Allocate out if needed, check alignment and build the FftwFFTPlan."""
        out = self.allocate_output(out, shape, dtype)
        if self.warn_on_misalignment:
            self.check_alignment(a, out)
        plan_kwds = self.bake_kwds(
            a=a, out=out, axis=axis, direction=direction, **kwds
        )
        return FftwFFTPlan(scaling=scaling, **plan_kwds)

    def fft(self, a, out=None, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        (shape, dtype) = super().fft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_FORWARD", kwds)

    def ifft(self, a, out=None, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        # NOTE(review): s is not applied as scaling here, unlike idct/idst;
        # presumably backward normalization comes from pyfftw's FFTW.__call__
        # defaults used by FftwFFTPlan.execute — confirm.
        (shape, dtype, s) = super().ifft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_BACKWARD", kwds)

    def rfft(self, a, out=None, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        (shape, dtype) = super().rfft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_FORWARD", kwds)

    def irfft(self, a, out=None, n=None, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        # NOTE(review): s is unused here, as in ifft — confirm normalization.
        (shape, dtype, s) = super().irfft(a=a, out=out, axis=axis, n=n, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_BACKWARD", kwds)

    def dct(self, a, out=None, type=2, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        (shape, dtype) = super().dct(a=a, out=out, type=type, axis=axis, **kwds)
        direction = self._r2r_kind(self._DCT_KINDS, type)
        return self._build_plan(a, out, shape, dtype, axis, direction, kwds)

    def idct(self, a, out=None, type=2, axis=-1, scaling=None, **kwds):
        """Planning destroys initial arrays content.

        scaling defaults to 1/s, with s the factor returned by the base class.
        """
        (shape, dtype, itype, s) = super().idct(
            a=a, out=out, type=type, axis=axis, **kwds
        )
        scaling = first_not_None(scaling, 1.0 / s)
        # The inverse is computed with the FFTW kind of type itype.
        direction = self._r2r_kind(self._DCT_KINDS, itype)
        return self._build_plan(
            a, out, shape, dtype, axis, direction, kwds, scaling=scaling
        )

    def dst(self, a, out=None, type=2, axis=-1, **kwds):
        """Planning destroys initial arrays content."""
        (shape, dtype) = super().dst(a=a, out=out, type=type, axis=axis, **kwds)
        direction = self._r2r_kind(self._DST_KINDS, type)
        return self._build_plan(a, out, shape, dtype, axis, direction, kwds)

    def idst(self, a, out=None, type=2, axis=-1, scaling=None, **kwds):
        """Planning destroys initial arrays content.

        scaling defaults to 1/s, with s the factor returned by the base class.
        """
        (shape, dtype, itype, s) = super().idst(
            a=a, out=out, type=type, axis=axis, **kwds
        )
        scaling = first_not_None(scaling, 1.0 / s)
        direction = self._r2r_kind(self._DST_KINDS, itype)
        return self._build_plan(
            a, out, shape, dtype, axis, direction, kwds, scaling=scaling
        )